
Add Numba JIT acceleration for nearest-neighbor VCE (260x speedup) #11

Open
sankalpsharmaa wants to merge 2 commits into rdpackages:master from sankalpsharmaa:numba-acceleration

Conversation


@sankalpsharmaa sankalpsharmaa commented Mar 13, 2026

Summary

  • JIT-compiles the nearest-neighbor residual computation in rdrobust_res using Numba, eliminating the pure Python for pos in range(n) loop that is the dominant bottleneck
  • Numba is an optional dependency: without it, behavior is completely unchanged
  • Adds parallel execution via prange for datasets with n > 50,000

Motivation

The NN VCE path in rdrobust_res (funs.py:294-319) contains a Python-level loop over all observations, each with an inner while loop for neighbor matching. This function is called 7-19 times per rdrobust() invocation (via rdrobust_bw in bandwidth selection). For n > 50K, this dominates total runtime.

Changes

| File | Change |
| --- | --- |
| _numba_core.py (new) | Serial + parallel @njit functions that replace the NN loop |
| funs.py | Import guard + dispatch to Numba at top of vce=="nn" branch |
| benchmark.py (new) | Correctness verification + speed benchmarks |

Performance (Apple M3, Python 3.11, Numba 0.64)

rdrobust_res NN path (left side):

| n | Python (s) | Numba (s) | Speedup |
| --- | --- | --- | --- |
| 5,000 | 0.025 | 0.0001 | 260x |
| 20,000 | 0.099 | 0.0004 | 258x |
| 50,000 | 0.241 | 0.0009 | 260x |
| 500,000 | - | 0.0023 | - |

Full rdrobust() end-to-end:

| n | Python (s) | Numba (s) | Speedup |
| --- | --- | --- | --- |
| 10,000 | 0.125 | 0.009 | 14x |
| 50,000 | 0.548 | 0.032 | 17x |
| 100,000 | 0.979 | 0.069 | 14x |
| 500,000 | ~5 (est.) | 0.349 | ~14x |

Numerical difference: exactly zero (verified at n=5K and n=50K, all coefficients, SEs, and CIs match to machine precision).

Design decisions

  • Numba over C++/Cython: The bottleneck is a scalar loop over numeric arrays, which is Numba's sweet spot. No build toolchain required, no platform-specific wheels.
  • Unified kernel: Stacks [y, T, Z] into a single 2D array and loops over columns, avoiding separate functions for sharp/fuzzy/covariate cases.
  • Graceful fallback: try/except ImportError around the Numba import means zero impact on users without Numba installed.

Test plan

  • Numerical equivalence verified against pure Python at multiple n
  • Sharp RD (y only), the default and most common case
  • Fuzzy RD (y + T) - same code path, needs test
  • With covariates (y + T + Z) - same code path, needs test
  • Cluster-robust VCE (not accelerated, should be unaffected)

The nearest-neighbor variance estimator in rdrobust_res contains a pure
Python for-loop over all n observations, each with an inner while-loop
for neighbor matching. This is the dominant bottleneck, called 7-19 times
per rdrobust() invocation.

This commit adds optional Numba JIT compilation for this hot path:

- New _numba_core.py with serial and parallel (prange) JIT functions
- funs.py dispatches to Numba when available, falls back to original
  Python code when numba is not installed
- Stacks [y, T, Z] into a contiguous 2D array for the JIT kernel

Performance (Apple M3, Python 3.11, numba 0.64):
- rdrobust_res NN path: 260x faster (0.24s -> 0.0009s at n=50K)
- Full rdrobust() end-to-end: 14-17x faster
- Numerical difference vs pure Python: exactly zero

Numba is an optional dependency. Without it, behavior is unchanged.
Copilot AI review requested due to automatic review settings March 13, 2026 11:08

Copilot AI left a comment


Pull request overview

This PR adds an optional Numba-accelerated implementation of the nearest-neighbor (NN) VCE residual computation in rdrobust_res, aiming to remove a Python-level O(n) hot loop and optionally enable parallel execution for large n.

Changes:

  • Introduces Numba JIT kernels (serial + prange parallel) for NN residual computation.
  • Adds an import guard and runtime dispatch in rdrobust_res to use the JIT kernels when available.
  • Adds a standalone benchmarking script to compare correctness/performance vs the pure-Python path.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| Python/rdrobust/src/rdrobust/_numba_core.py | New Numba-compiled NN residual kernels (serial + parallel) and a parallelization threshold constant. |
| Python/rdrobust/src/rdrobust/funs.py | Adds optional Numba import and dispatch in the vce=="nn" branch of rdrobust_res. |
| benchmark.py | New benchmark/correctness script to compare Numba vs pure-Python behavior and speed. |


Comment on lines +14 to +16
@njit(cache=True)
def nn_residuals(X, D, matches, dups, dupsid, n, ncols):
"""
Author

Fixed in b22b55e. Removed n and ncols from both JIT signatures. They're now derived from X.shape[0] and D.shape[1] at the top of each kernel.

Comment on lines +305 to +311
X_flat = np.ascontiguousarray(np.asarray(X).ravel(), dtype=np.float64)
D = np.ascontiguousarray(np.asarray(y).reshape(-1, 1), dtype=np.float64)
if T is not None:
    D = np.column_stack((D, np.asarray(T).ravel()))
if Z is not None:
    D = np.column_stack((D, np.asarray(Z)))
D = np.ascontiguousarray(D, dtype=np.float64)
Author

Good catch. The explicit float64 cast is intentional for Numba JIT type stability and cache reuse (avoids recompilation per dtype). In practice rdrobust's internal pipeline always produces float64 arrays, so the two paths are equivalent. Added a comment at the cast site explaining the rationale.
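The rationale can be illustrated with a minimal helper (the helper name and the sample inputs are stand-ins, not the PR's code): pinning every input to contiguous float64 keeps the JIT kernel's type signature constant, so Numba compiles and caches one specialization instead of one per dtype/layout it happens to see.

```python
import numpy as np

def to_kernel_input(a):
    # One fixed signature for the JIT kernel: C-contiguous float64.
    return np.ascontiguousarray(np.asarray(a), dtype=np.float64)

mixed = [
    np.arange(4, dtype=np.int32),            # wrong dtype
    np.arange(4, dtype=np.float32)[::-1],    # non-contiguous view
    [0.0, 1.0, 2.0, 3.0],                    # plain Python list
]
prepared = [to_kernel_input(a) for a in mixed]
```

Each element of `prepared` now has the same dtype and memory layout, so a `@njit`-compiled function called on any of them reuses the same compiled specialization.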

benchmark.py Outdated
Comment on lines +110 to +115
print("Warming up Numba JIT...")
x_warm, y_warm = generate_rd_data(1000, seed=0)
sides = prepare_nn_inputs(x_warm, y_warm)
X_s, Y_s, dups, dupsid, n_s = sides[0]
_ = rdrobust_res(X_s, Y_s, None, None, 0, 0, "nn", 3, dups, dupsid, 2)
print("JIT warm-up done.\n")
Author

Fixed in b22b55e. Warm-up block is now guarded behind if has_numba:, so it only runs (and prints) when Numba is actually available.

PARALLEL_THRESHOLD,
)
_HAS_NUMBA = True
except ImportError:
Author

Fixed in b22b55e. Import is now lazy via _try_numba(), which runs once on the first vce=="nn" call. Also catches broad Exception instead of just ImportError to handle LLVM/runtime failures gracefully.

Comment on lines +314 to +317
ncols = D.shape[1]
if n > PARALLEL_THRESHOLD:
return _nn_res_parallel(X_flat, D, matches, dups_i, dupsid_i, n, ncols)
return _nn_res_jit(X_flat, D, matches, dups_i, dupsid_i, n, ncols)
Author

Fixed in b22b55e. n and ncols are no longer passed to the kernels. Both are derived from the array shapes inside the JIT functions, so mismatches are impossible.

def generate_rd_data(n, seed=42):
    """Generate synthetic sharp RD data."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(-1, 1, n)
Author

Fixed in b22b55e. Added generate_rd_data_with_masspoints() which rounds X to 1 decimal place, creating duplicate running-variable values. This is now correctness check #2 in the benchmark.
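A generator of that shape might look like the sketch below. The name and the round-to-one-decimal trick follow the comment above; the outcome model is an illustrative stand-in.

```python
import numpy as np

def generate_rd_data_with_masspoints(n, seed=42):
    """Sharp RD data whose running variable has mass points (duplicates)."""
    rng = np.random.default_rng(seed)
    # Rounding to 1 decimal place collapses X onto at most 21 distinct
    # values in [-1, 1], guaranteeing ties for the NN matcher to handle.
    x = np.round(rng.uniform(-1, 1, n), 1)
    # Illustrative DGP: linear trend plus a jump at the cutoff.
    y = 0.5 * x + 0.3 * (x >= 0) + rng.normal(0, 0.1, n)
    return x, y

x, y = generate_rd_data_with_masspoints(5000)
n_unique = np.unique(x).size
```

With n = 5,000 draws over at most 21 grid points, duplicates are guaranteed, which is exactly the mass-point regime the dups/dupsid bookkeeping exists for.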

benchmark.py Outdated
Comment on lines +117 to +134
# ── Correctness check ──
print("=" * 60)
print("CORRECTNESS CHECK (n=5,000)")
print("=" * 60)
x_check, y_check = generate_rd_data(5000, seed=1)
sides = prepare_nn_inputs(x_check, y_check)
X_s, Y_s, dups, dupsid, n_s = sides[0]

_, res_python = time_pure_python_res(X_s, Y_s, dups, dupsid, n_s)
_, res_numba = time_rdrobust_res(X_s, Y_s, dups, dupsid, n_s)

max_diff = np.max(np.abs(res_python - res_numba))
print(f"Max absolute difference: {max_diff:.2e}")
if max_diff < 1e-10:
    print("PASS: Results are numerically identical.\n")
else:
    print("FAIL: Results differ!\n")
    return
Author

Fixed in b22b55e. Correctness checks now cover all four cases:

  1. Sharp RD, continuous X
  2. Sharp RD, mass points (duplicates)
  3. Fuzzy RD (y + T)
  4. With covariates (y + T + Z)

All pass with zero numerical difference.

- Remove redundant n/ncols params from JIT signatures; derive from
  X.shape[0] and D.shape[1] inside the kernels
- Make Numba import lazy (deferred to first vce=="nn" call) and catch
  broad Exception instead of just ImportError (handles LLVM/runtime)
- Add comment explaining float64 cast rationale in Numba dispatch path
- Guard JIT warm-up behind _HAS_NUMBA check in benchmark
- Add mass-point correctness test (rounded X with duplicates)
- Add fuzzy RD (y + T) and covariate (y + T + Z) correctness checks
- Extend prepare_nn_inputs to pass T and Z through to each side